"""
30 nodes situation

"""
from torch import nn
import torch, os, json
import matplotlib.pyplot as plt

from tqdm import tqdm
import numpy as np
import torch_geometric as pyg
from torch_geometric.data import InMemoryDataset, Data



class Encoder(nn.Module):
    """Multi-layer perceptron with three Mish-activated hidden layers.

    Maps vectors of width ``input_size`` to vectors of width ``output_size``;
    every hidden layer has ``hidden_size`` units.
    """

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()

        self.output_size = output_size

        # Three (Linear -> Mish) stages followed by a plain linear read-out.
        # Modules are created in forward order so the Sequential indices
        # (0, 2, 4, 6 for the Linear layers) stay stable for checkpoints.
        widths = [input_size, hidden_size, hidden_size, hidden_size]
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.Mish())
        stages.append(nn.Linear(hidden_size, output_size))

        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dimension must be ``input_size``)."""
        return self.layers(x)


class Decoder(nn.Module):
    """Two-layer perceptron that reduces each input row to a single scalar.

    One Mish-activated hidden layer of ``hidden_size`` units followed by a
    linear layer with one output.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()

        hidden = nn.Linear(input_size, hidden_size)
        readout = nn.Linear(hidden_size, 1)
        self.layers = nn.Sequential(hidden, nn.Mish(), readout)

    def forward(self, x):
        """Return a tensor whose last dimension is 1 for input ``x``."""
        out = self.layers(x)
        return out


class CN(nn.Module):
    """Interaction-style network: encode per-relation effects, aggregate them
    on the receiving objects, then decode one scalar per object.

    The encoder consumes the receiver features, the sender features and the
    relation feature of every relation (2 * 4 + 1 inputs) and emits a
    50-dimensional effect. The decoder consumes the object features
    concatenated with the summed incoming effects (4 + 50 inputs).
    """

    def __init__(self):
        super().__init__()
        object_dim = 4  # node features
        relation_dim = 1  # edge features
        effect_dim = 50
        x_external_dim = 0  # no external per-object inputs are used
        self.encoder_model = Encoder(2 * object_dim + relation_dim, effect_dim, 150)
        self.decoder_model = Decoder(object_dim + effect_dim + x_external_dim, 100)

    def forward(self, objects, sender_relations, receiver_relations, relation_info):
        """Predict one value per object.

        Args:
            objects: per-object features, one row per object (4 columns are
                required by the encoder/decoder widths).
            sender_relations: matrix with one row per object and one column
                per relation; used to gather sender features per relation.
            receiver_relations: same layout as ``sender_relations``; used to
                gather receiver features and to sum effects per object.
            relation_info: per-relation features, one row per relation
                (1 column is required by the encoder width).

        Returns:
            Tensor with one row per object and a single column.
        """
        # Cast every input exactly once. The original cast `objects` for the
        # matmuls but concatenated the raw tensor for the decoder, so a
        # non-float32 `objects` was promoted by torch.cat and then rejected by
        # the float32 decoder weights.
        objects = objects.float()
        sender_relations = sender_relations.float()
        receiver_relations = receiver_relations.float()
        relation_info = relation_info.float()

        # Gather the features of the sending / receiving object of each relation.
        senders = torch.matmul(torch.t(sender_relations), objects)
        receivers = torch.matmul(torch.t(receiver_relations), objects)

        m = torch.cat((receivers, senders, relation_info), 1)
        effects = self.encoder_model(m)

        # Sum the effects of all relations that point at each object.
        effect_receivers = torch.matmul(receiver_relations, effects)

        aggregation_result = torch.cat((objects, effect_receivers), 1)
        predicted = self.decoder_model(aggregation_result)
        return predicted


def one_d_data_process(data_path, split, *, total_simulation=1000, total_time_step=10000,
                       node_size=32, dt=1e-4, left_boundary=-1, right_boundary=201):
    """Build one graph sample per (simulation, time step) from a raw trajectory file.

    The file ``{split}_position.dat`` inside ``data_path`` is read as a float32
    array of shape ``(total_simulation, total_time_step, node_size, 2)``.
    Channel 0 is used as the node position and channel 1 as the prediction
    target (named "velocity" in this code).

    For every time step ``t >= 2`` a sample is created with

    * ``x``: per-node features ``[x(t-2), x(t-1), (x(t-1) - x(t-2)) / dt, 1]``,
      where the first and last node positions are replaced by
      ``left_boundary`` / ``right_boundary``;
    * ``edge_attr``: for each adjacent node pair ``(k, k+1)`` two rows,
      ``x[k+1] - x[k]`` followed by its negation, taken at ``t-1``;
    * ``y``: channel 1 of every node at time ``t``.

    No ``edge_index`` is stored; connectivity is supplied separately by the
    caller.

    Args:
        data_path: directory containing the ``.dat`` file.
        split: file name prefix, e.g. ``"train"``.
        total_simulation: number of simulations stored in the file.
        total_time_step: number of time steps per simulation.
        node_size: number of nodes per time step (including both boundary nodes).
        dt: time increment used for the finite-difference velocity.
        left_boundary: position assigned to the first node.
        right_boundary: position assigned to the last node.

    Returns:
        list of ``torch_geometric.data.Data`` objects, ordered by simulation
        and then by time step.
    """
    # last axis holds 2 channels: sample_size, time_size, node_size, 2
    positions = np.memmap(os.path.join(data_path, f"{split}_position.dat"), dtype=np.float32, mode="c",
                          shape=(total_simulation, total_time_step, node_size, 2))

    # Constant fourth node feature; identical for every sample.
    radius = np.ones((node_size, 1), dtype=np.float32)

    data_list = []
    for simulation in range(total_simulation):
        for t in range(2, total_time_step):
            # Copy the slices before editing them: the originals are views
            # into the memmap, so assigning the boundary values in place would
            # write through to its copy-on-write pages.
            prev_x = np.array(positions[simulation, t - 2, :, 0], dtype=np.float32).reshape(node_size, 1)
            curr_x = np.array(positions[simulation, t - 1, :, 0], dtype=np.float32).reshape(node_size, 1)
            prev_x[0] = left_boundary
            prev_x[node_size - 1] = right_boundary
            curr_x[0] = left_boundary
            curr_x[node_size - 1] = right_boundary

            velocity = (curr_x - prev_x) / dt
            node_features = np.concatenate((prev_x, curr_x, velocity, radius), axis=1)

            target_velocity = positions[simulation, t, :, 1]

            # Two directed edges per adjacent pair: +gap, then -gap.
            gap = curr_x[1:, 0] - curr_x[:-1, 0]
            edge_features = np.stack((gap, -gap), axis=1).reshape(2 * (node_size - 1), 1)

            graph = Data(x=torch.from_numpy(node_features).float(),
                         edge_attr=torch.from_numpy(edge_features).float(),
                         y=torch.from_numpy(target_velocity).float())
            data_list.append(graph)

    return data_list


def train(params, simulator, train_loader, previous_step, sender_relations, receiver_relations):
    """Train ``simulator`` with Adam + MSE loss and save a loss plot at the end.

    Relies on the module-level names ``model_path`` and ``DATASET_NAME`` (set
    in the ``__main__`` section of this file) for checkpoint and plot paths.

    Args:
        params: dict with the keys ``"lr"``, ``"epoch"``, ``"eval_interval"``
            and ``"save_interval"``.
        simulator: the model; called as
            ``simulator(x, sender_relations, receiver_relations, edge_attr)``.
        train_loader: iterable of batches exposing ``x``, ``edge_attr`` and ``y``.
        previous_step: offset added to the step counter in checkpoint file names.
        sender_relations: NumPy matrix (rows: nodes, columns: edges) for a
            full batch.
        receiver_relations: NumPy matrix with the same layout as
            ``sender_relations``.
    """
    loss_fn = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(simulator.parameters(), lr=params["lr"])
    # Stepped once per batch: the learning rate shrinks by 10x every 5e6 steps.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1 ** (1 / 5e6))
    train_loss_list = []
    total_step = 0

    # Follow the model's device instead of assuming CUDA, and convert / move
    # the relation matrices once rather than on every iteration.
    device = next(simulator.parameters()).device
    sender_relations = torch.from_numpy(sender_relations).float().to(device)
    receiver_relations = torch.from_numpy(receiver_relations).float().to(device)

    for i in range(params["epoch"]):
        simulator.train()
        progress_bar = tqdm(train_loader, desc=f"Epoch {i}")
        total_loss = 0
        batch_count = 0
        for data in progress_bar:
            optimizer.zero_grad()
            data = data.to(device)

            # A final, smaller batch has fewer nodes/edges than the matrices
            # were built for. Taking the leading sub-matrix keeps shapes
            # consistent; this assumes the block-diagonal layout (one block
            # per graph, all graphs the same size) that the caller builds.
            n_nodes = data.x.shape[0]
            n_edges = data.edge_attr.shape[0]
            pred = simulator(data.x,
                             sender_relations[:n_nodes, :n_edges],
                             receiver_relations[:n_nodes, :n_edges],
                             data.edge_attr)

            target = data.y.reshape(-1, 1)
            loss = loss_fn(pred, target)
            loss.backward()
            optimizer.step()
            scheduler.step()

            loss_value = loss.item()
            total_loss += loss_value
            batch_count += 1
            progress_bar.set_postfix({"loss": loss_value, "avg_loss": total_loss / batch_count, "lr": optimizer.param_groups[0]["lr"]})
            train_loss_list.append(loss_value)
            # Periodically print the running average loss of the current epoch.
            if total_step % params["eval_interval"] == 0:
                print(f"avg_loss: {total_loss / batch_count}")
            total_step += 1

            # save model
            if total_step % params["save_interval"] == 0:
                torch.save(
                    {
                        "model": simulator.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                    },
                    os.path.join(model_path, f"{DATASET_NAME}_{total_step + previous_step}.pt")
                )

    # Plot the per-batch training loss collected over the whole run.
    fig, ax = plt.subplots(num=1, clear=True)
    ax.set_ylabel('loss')
    ax.set_title('training loss')
    ax.set_ylim(0, 0.00001)
    ax.plot(train_loss_list, label='loss')
    ax.legend(loc='upper right')
    folder_path = f'./img/training_loss_{DATASET_NAME}/'
    os.makedirs(folder_path, exist_ok=True)
    plt.savefig(f'{folder_path}/loss.png')

# Training hyper-parameters shared by the data loader and train().
params = dict(
    epoch=2000,  # number of passes over the training loader
    batch_size=100,  # graphs per batch; also sizes the batched relation matrices
    lr=1e-4,  # initial Adam learning rate
    save_interval=100000,  # optimisation steps between checkpoints
    eval_interval=2000000,  # optimisation steps between running-average loss prints
)


if __name__ == '__main__':
    DATASET_NAME = "CN_30"
    model_path = f"./models/{DATASET_NAME}"

    # train() writes its checkpoints here; create the directory up front.
    os.makedirs(model_path, exist_ok=True)
    folder_path = './data_IN'  # change to the data path

    use_cuda = torch.cuda.is_available()

    train_dataset = one_d_data_process(folder_path, "train")
    # Pinned host memory only helps host-to-GPU copies, so request it only
    # when a GPU is present.
    train_loader = pyg.loader.DataLoader(train_dataset, batch_size=params["batch_size"], shuffle=True,
                                         pin_memory=use_cuda)

    device = torch.device('cuda' if use_cuda else 'cpu')
    simulator = CN().to(device)

    previous_step = 0

    batch_size = params['batch_size']
    n_objects = 32  # 30 interior nodes plus the two boundary nodes
    n_relations = (n_objects - 1) * 2  # two directed edges per adjacent pair

    # Construct receiver_relations and sender_relations for a single graph.
    # Adjacent nodes k and k+1 share two edges, matching the edge feature
    # order produced by one_d_data_process:
    #   edge 2k     : sender k,     receiver k + 1
    #   edge 2k + 1 : sender k + 1, receiver k
    # Looping over pairs covers interior and boundary nodes alike.
    receiver_relations = np.zeros((n_objects, n_relations), dtype=float)
    sender_relations = np.zeros((n_objects, n_relations), dtype=float)
    for k in range(n_objects - 1):
        sender_relations[k, 2 * k] = 1.0
        receiver_relations[k + 1, 2 * k] = 1.0

        sender_relations[k + 1, 2 * k + 1] = 1.0
        receiver_relations[k, 2 * k + 1] = 1.0

    # Block-diagonal tiling: one copy of the single-graph matrix per graph in
    # a batch, so graphs in the same batch do not interact.
    batch_identity = np.eye(batch_size)
    sender_relations_batch = np.kron(batch_identity, sender_relations)
    receiver_relations_batch = np.kron(batch_identity, receiver_relations)

    train(params, simulator, train_loader, previous_step, sender_relations_batch, receiver_relations_batch)


